{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Mnist分类任务：\n",
    "\n",
    "- 网络基本构建与训练方法，常用函数解析\n",
    "\n",
    "- torch.nn.functional模块\n",
    "\n",
    "- nn.Module模块\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 读取MNIST数据集\n",
    "- 若本地 `data/mnist/` 目录下还没有 `mnist.pkl.gz`，会自动下载"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render matplotlib figures inline, below the cell that creates them.\n",
    "%matplotlib inline\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "读取数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import requests\n",
    "\n",
    "# Local cache location for the pickled MNIST dataset.\n",
    "DATA_PATH = Path(\"data\")\n",
    "PATH = DATA_PATH / \"mnist\"\n",
    "\n",
    "PATH.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# deeplearning.net is offline; the PyTorch tutorials repository hosts the same mnist.pkl.gz.\n",
    "URL = \"https://github.com/pytorch/tutorials/raw/main/_static/\"\n",
    "FILENAME = \"mnist.pkl.gz\"\n",
    "\n",
    "# Download only when the file is not cached yet.\n",
    "if not (PATH / FILENAME).exists():\n",
    "    response = requests.get(URL + FILENAME, timeout=60)\n",
    "    # Fail loudly rather than caching an HTML error page under FILENAME,\n",
    "    # which would make the exists() check above skip every later retry.\n",
    "    response.raise_for_status()\n",
    "    # write_bytes opens, writes and closes the file in one call.\n",
    "    (PATH / FILENAME).write_bytes(response.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gzip\n",
    "import pickle\n",
    "\n",
    "# The pickle holds three (images, labels) pairs: training, validation and a\n",
    "# third split (presumably test) that is discarded here.\n",
    "# encoding=\"latin-1\" lets Python 3 decode the pickle's byte strings\n",
    "# (presumably the file was written by Python 2).\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.\n",
    "with gzip.open((PATH / FILENAME).as_posix(), \"rb\") as archive:\n",
    "    (x_train, y_train), (x_valid, y_valid), _ = pickle.load(archive, encoding=\"latin-1\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "784 = 28×28，是MNIST每张图像展平成一维向量后的像素点个数（下面用 `reshape((28, 28))` 还原并显示第一张图像）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(50000, 784)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGdCAYAAAC7EMwUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAaI0lEQVR4nO3df2jU9x3H8dfVH1d1lytBk7vUmGVF202dpWrVYP3R1cxApf4oWMtGZEPa+YOJ/cGsDNNBjdgpRdI6V0amW239Y9a6KdUMTXRkijpdRYtYjDOdCcFM72LUSMxnf4hHz1j1e975vkueD/iCufu+vY/ffuvTby75xueccwIAwMBD1gsAAHRfRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJjpab2AW3V0dOjcuXMKBALy+XzWywEAeOScU0tLi/Ly8vTQQ3e+1km7CJ07d075+fnWywAA3Kf6+noNHDjwjvuk3afjAoGA9RIAAElwL3+fpyxCH3zwgQoLC/Xwww9r5MiR2rdv3z3N8Sk4AOga7uXv85REaPPmzVq8eLGWLVumI0eO6JlnnlFJSYnOnj2bipcDAGQoXyruoj1mzBg99dRTWrduXeyx73//+5o+fbrKy8vvOBuNRhUMBpO9JADAAxaJRJSVlXXHfZJ+JXTt2jUdPnxYxcXFcY8XFxertra20/5tbW2KRqNxGwCge0h6hM6fP6/r168rNzc37vHc3Fw1NjZ22r+8vFzBYDC28ZVxANB9pOwLE259Q8o5d9s3qZYuXapIJBLb6uvrU7UkAECaSfr3CfXv3189evTodNXT1NTU6epIkvx+v/x+f7KXAQDIAEm/Eurdu7dGjhypqqqquMerqqpUVFSU7JcDAGSwlNwxYcmSJfrpT3+qUaNGady4cfr973+vs2fP6tVXX03FywEAMlRKIjR79mw1NzfrN7/5jRoaGjRs2DDt2LFDBQUFqXg5AECGSsn3Cd0Pvk8IALoGk+8TAgDgXhEhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmelovAEgnPXr08DwTDAZTsJLkWLhwYUJzffv29Tzz+OOPe55ZsGCB55nf/va3nmfmzJnjeUaSrl696nlm5cqVnmfefvttzzNdBVdCAAAzRAgAYCbpESorK5PP54vbQqFQsl8GANAFpOQ9oaFDh+rvf/977ONEPs8OAOj6UhKhnj17cvUDALirlLwndOrUKeXl5amwsFAvvfSSTp8+/a37trW1KRqNxm0AgO4h6REaM2aMNm7cqJ07d+rDDz9UY2OjioqK1NzcfNv9y8vLFQwGY1t+fn6ylwQASFNJj1BJSYlmzZql4cOH67nnntP27dslSRs2bLjt/kuXLlUkEolt9fX1yV4SACBNpfybVfv166fhw4fr1KlTt33e7/fL7/enehkAgDSU8u8Tamtr05dffqlwOJzqlwIAZJikR+j1119XTU2N6urqdODAAb344ouKRqMqLS1N9ksBADJc0j8d9/XXX2vOnDk6f/68BgwYoLFjx2r//v0qKChI9ksBADJc0iP0ySefJPu3RJoaNGiQ55nevXt7nikqKvI8M378eM8zkvTII494npk1a1ZCr9XVfP31155n1q5d63lmxowZnmdaWlo8z0jSv//9b88zNTU1Cb1Wd8W94wAAZogQAMAMEQIAmCFCAAAzRAgA
YIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAMz7nnLNexDdFo1EFg0HrZXQrTz75ZEJzu3fv9jzDf9vM0NHR4XnmZz/7meeZS5cueZ5JRENDQ0JzFy5c8Dxz8uTJhF6rK4pEIsrKyrrjPlwJAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwExP6wXA3tmzZxOaa25u9jzDXbRvOHDggOeZixcvep6ZPHmy5xlJunbtmueZP/3pTwm9Fro3roQAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADPcwBT63//+l9DcG2+84Xnm+eef9zxz5MgRzzNr1671PJOoo0ePep6ZMmWK55nW1lbPM0OHDvU8I0m//OUvE5oDvOJKCABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAw43POOetFfFM0GlUwGLReBlIkKyvL80xLS4vnmfXr13uekaSf//znnmd+8pOfeJ75+OOPPc8AmSYSidz1/3muhAAAZogQAMCM5wjt3btX06ZNU15ennw+n7Zu3Rr3vHNOZWVlysvLU58+fTRp0iQdP348WesFAHQhniPU2tqqESNGqKKi4rbPr1q1SmvWrFFFRYUOHjyoUCikKVOmJPR5fQBA1+b5J6uWlJSopKTkts855/Tee+9p2bJlmjlzpiRpw4YNys3N1aZNm/TKK6/c32oBAF1KUt8TqqurU2Njo4qLi2OP+f1+TZw4UbW1tbedaWtrUzQajdsAAN1DUiPU2NgoScrNzY17PDc3N/bcrcrLyxUMBmNbfn5+MpcEAEhjKfnqOJ/PF/exc67TYzctXbpUkUgkttXX16diSQCANOT5PaE7CYVCkm5cEYXD4djjTU1Nna6ObvL7/fL7/clcBgAgQyT1SqiwsFChUEhVVVWxx65du6aamhoVFRUl86UAAF2A5yuhS5cu6auvvop9XFdXp6NHjyo7O1uDBg3S4sWLtWLFCg0ePFiDBw/WihUr1LdvX7388stJXTgAIPN5jtChQ4c0efLk2MdLliyRJJWWluqPf/yj3nzzTV25ckXz58/XhQsXNGbMGO3atUuBQCB5qwYAdAncwBRd0rvvvpvQ3M1/VHlRU1Pjeea5557zPNPR0eF5BrDEDUwBAGmNCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZriLNrqkfv36JTT317/+1fPMxIkTPc+UlJR4ntm1a5fnGcASd9EGAKQ1IgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMNzAFvuGxxx7zPPOvf/3L88zFixc9z+zZs8fzzKFDhzzPSNL777/veSbN/ipBGuAGpgCAtEaEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmOEGpsB9mjFjhueZyspKzzOBQMDzTKLeeustzzMbN270PNPQ0OB5BpmDG5gCANIaEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGG5gCBoYNG+Z5Zs2aNZ5nfvSjH3meSdT69es9z7zzzjueZ/773/96noENbmAKAEhrRAgAYMZzhPbu3atp06YpLy9PPp9PW7dujXt+7ty58vl8cdvYsWOTtV4AQBfiOUKtra0aMWKEKioqvnWfqVOnqqGhIbbt2LHjvhYJAOiaenodKCkpUUlJyR338fv9CoVCCS8KANA9pOQ9oerqauXk5GjIkCGaN2+empqavnXftrY2RaPRuA0A0D0kPUIlJSX66KOPtHv3bq1evVoHDx7U
s88+q7a2ttvuX15ermAwGNvy8/OTvSQAQJry/Om4u5k9e3bs18OGDdOoUaNUUFCg7du3a+bMmZ32X7p0qZYsWRL7OBqNEiIA6CaSHqFbhcNhFRQU6NSpU7d93u/3y+/3p3oZAIA0lPLvE2publZ9fb3C4XCqXwoAkGE8XwldunRJX331Vezjuro6HT16VNnZ2crOzlZZWZlmzZqlcDisM2fO6K233lL//v01Y8aMpC4cAJD5PEfo0KFDmjx5cuzjm+/nlJaWat26dTp27Jg2btyoixcvKhwOa/Lkydq8ebMCgUDyVg0A6BK4gSmQIR555BHPM9OmTUvotSorKz3P+Hw+zzO7d+/2PDNlyhTPM7DBDUwBAGmNCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZriLNoBO2traPM/07On9BzW3t7d7nvnxj3/seaa6utrzDO4fd9EGAKQ1IgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMCM9zsOArhvP/zhDz3PvPjii55nRo8e7XlGSuxmpIk4ceKE55m9e/emYCWwwpUQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGG5gC3/D44497nlm4cKHnmZkzZ3qeCYVCnmcepOvXr3ueaWho8DzT0dHheQbpiyshAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMNzBF2kvkxp1z5sxJ6LUSuRnpd7/73YReK50dOnTI88w777zjeWbbtm2eZ9C1cCUEADBDhAAAZjxFqLy8XKNHj1YgEFBOTo6mT5+ukydPxu3jnFNZWZny8vLUp08fTZo0ScePH0/qogEAXYOnCNXU1GjBggXav3+/qqqq1N7eruLiYrW2tsb2WbVqldasWaOKigodPHhQoVBIU6ZMUUtLS9IXDwDIbJ6+MOHzzz+P+7iyslI5OTk6fPiwJkyYIOec3nvvPS1btiz2kyM3bNig3Nxcbdq0Sa+88kryVg4AyHj39Z5QJBKRJGVnZ0uS6urq1NjYqOLi4tg+fr9fEydOVG1t7W1/j7a2NkWj0bgNANA9JBwh55yWLFmi8ePHa9iwYZKkxsZGSVJubm7cvrm5ubHnblVeXq5gMBjb8vPzE10SACDDJByhhQsX6osvvtDHH3/c6Tmfzxf3sXOu02M3LV26VJFIJLbV19cnuiQAQIZJ6JtVFy1apG3btmnv3r0aOHBg7PGb31TY2NiocDgce7ypqanT1dFNfr9ffr8/kWUAADKcpysh55wWLlyoLVu2aPfu3SosLIx7vrCwUKFQSFVVVbHHrl27ppqaGhUVFSVnxQCALsPTldCCBQu0adMmffbZZwoEArH3eYLBoPr06SOfz6fFixdrxYoVGjx4sAYPHqwVK1aob9++evnll1PyBwAAZC5PEVq3bp0kadKkSXGPV1ZWau7cuZKkN998U1euXNH8+fN14cIFjRkzRrt27VIgEEjKggEAXYfPOeesF/FN0WhUwWDQehm4B9/2Pt+d/OAHP/A8U1FR4XnmiSee8DyT7g4cOOB55t13303otT777DPPMx0dHQm9FrquSCSirKysO+7DveMAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABgJqGfrIr0lZ2d7Xlm/fr1Cb3Wk08+6Xnme9/7XkKvlc5qa2s9z6xevdrzzM6dOz3PXLlyxfMM8CBxJQQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmOEGpg/ImDFjPM+88cYbnmeefvppzzOPPvqo55l0d/ny5YTm1q5d
63lmxYoVnmdaW1s9zwBdEVdCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZbmD6gMyYMeOBzDxIJ06c8Dzzt7/9zfNMe3u755nVq1d7npGkixcvJjQHIDFcCQEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZnzOOWe9iG+KRqMKBoPWywAA3KdIJKKsrKw77sOVEADADBECAJjxFKHy8nKNHj1agUBAOTk5mj59uk6ePBm3z9y5c+Xz+eK2sWPHJnXRAICuwVOEampqtGDBAu3fv19VVVVqb29XcXGxWltb4/abOnWqGhoaYtuOHTuSumgAQNfg6Serfv7553EfV1ZWKicnR4cPH9aECRNij/v9foVCoeSsEADQZd3Xe0KRSESSlJ2dHfd4dXW1cnJyNGTIEM2bN09NTU3f+nu0tbUpGo3GbQCA7iHhL9F2zumFF17QhQsXtG/fvtjjmzdv1ne+8x0VFBSorq5Ov/71r9Xe3q7Dhw/L7/d3+n3Kysr09ttvJ/4nAACkpXv5Em25BM2fP98VFBS4+vr6O+537tw516tXL/eXv/zlts9fvXrVRSKR2FZfX+8ksbGxsbFl+BaJRO7aEk/vCd20aNEibdu2TXv37tXAgQPvuG84HFZBQYFOnTp12+f9fv9tr5AAAF2fpwg557Ro0SJ9+umnqq6uVmFh4V1nmpubVV9fr3A4nPAiAQBdk6cvTFiwYIH+/Oc/a9OmTQoEAmpsbFRjY6OuXLkiSbp06ZJef/11/fOf/9SZM2dUXV2tadOmqX///poxY0ZK/gAAgAzm5X0gfcvn/SorK51zzl2+fNkVFxe7AQMGuF69erlBgwa50tJSd/bs2Xt+jUgkYv55TDY2Nja2+9/u5T0hbmAKAEgJbmAKAEhrRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzaRch55z1EgAASXAvf5+nXYRaWlqslwAASIJ7+fvc59Ls0qOjo0Pnzp1TIBCQz+eLey4ajSo/P1/19fXKysoyWqE9jsMNHIcbOA43cBxuSIfj4JxTS0uL8vLy9NBDd77W6fmA1nTPHnroIQ0cOPCO+2RlZXXrk+wmjsMNHIcbOA43cBxusD4OwWDwnvZLu0/HAQC6DyIEADCTURHy+/1avny5/H6/9VJMcRxu4DjcwHG4geNwQ6Ydh7T7wgQAQPeRUVdCAICuhQgBAMwQIQCAGSIEADCTURH64IMPVFhYqIcfflgjR47Uvn37rJf0QJWVlcnn88VtoVDIelkpt3fvXk2bNk15eXny+XzaunVr3PPOOZWVlSkvL099+vTRpEmTdPz4cZvFptDdjsPcuXM7nR9jx461WWyKlJeXa/To0QoEAsrJydH06dN18uTJuH26w/lwL8chU86HjInQ5s2btXjxYi1btkxHjhzRM888o5KSEp09e9Z6aQ/U0KFD1dDQENuOHTtmvaSUa21t1YgRI1RRUXHb51etWqU1a9aooqJCBw8eVCgU0pQpU7rcfQjvdhwkaerUqXHnx44dOx7gClOvpqZGCxYs0P79+1VVVaX29nYVFxertbU1tk93OB/u5ThIGXI+uAzx9NNPu1dffTXusSeeeML96le/MlrRg7d8+XI3YsQI62WYkuQ+/fTT2McdHR0uFAq5lStXxh67evWqCwaD7ne/+53BCh+MW4+Dc86Vlpa6F154wWQ9VpqampwkV1NT45zrvufDrcfBucw5HzLiSujatWs6fPiwiouL4x4vLi5WbW2t0apsnDp1Snl5eSosLNRLL72k06dPWy/JVF1dnRobG+PODb/fr4kTJ3a7
c0OSqqurlZOToyFDhmjevHlqamqyXlJKRSIRSVJ2drak7ns+3HocbsqE8yEjInT+/Hldv35dubm5cY/n5uaqsbHRaFUP3pgxY7Rx40bt3LlTH374oRobG1VUVKTm5mbrpZm5+d+/u58bklRSUqKPPvpIu3fv1urVq3Xw4EE9++yzamtrs15aSjjntGTJEo0fP17Dhg2T1D3Ph9sdBylzzoe0u4v2ndz6ox2cc50e68pKSkpivx4+fLjGjRunxx57TBs2bNCSJUsMV2avu58bkjR79uzYr4cNG6ZRo0apoKBA27dv18yZMw1XlhoLFy7UF198oX/84x+dnutO58O3HYdMOR8y4kqof//+6tGjR6d/yTQ1NXX6F0930q9fPw0fPlynTp2yXoqZm18dyLnRWTgcVkFBQZc8PxYtWqRt27Zpz549cT/6pbudD992HG4nXc+HjIhQ7969NXLkSFVVVcU9XlVVpaKiIqNV2Wtra9OXX36pcDhsvRQzhYWFCoVCcefGtWvXVFNT063PDUlqbm5WfX19lzo/nHNauHChtmzZot27d6uwsDDu+e5yPtztONxO2p4Phl8U4cknn3zievXq5f7whz+4EydOuMWLF7t+/fq5M2fOWC/tgXnttddcdXW1O336tNu/f797/vnnXSAQ6PLHoKWlxR05csQdOXLESXJr1qxxR44ccf/5z3+cc86tXLnSBYNBt2XLFnfs2DE3Z84cFw6HXTQaNV55ct3pOLS0tLjXXnvN1dbWurq6Ordnzx43btw49+ijj3ap4/CLX/zCBYNBV11d7RoaGmLb5cuXY/t0h/Phbschk86HjImQc869//77rqCgwPXu3ds99dRTcV+O2B3Mnj3bhcNh16tXL5eXl+dmzpzpjh8/br2slNuzZ4+T1GkrLS11zt34stzly5e7UCjk/H6/mzBhgjt27JjtolPgTsfh8uXLrri42A0YMMD16tXLDRo0yJWWlrqzZ89aLzupbvfnl+QqKytj+3SH8+FuxyGTzgd+lAMAwExGvCcEAOiaiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAz/wdVbyhNmNF0pQAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot\n",
    "import numpy as np\n",
    "\n",
    "# Each row of x_train is one flattened 28x28 image; reshape the first row to display it.\n",
    "sample_image = x_train[0].reshape((28, 28))\n",
    "pyplot.imshow(sample_image, cmap=\"gray\")\n",
    "print(x_train.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"./img/4.png\" alt=\"FAO\" width=\"790\">"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"./img/5.png\" alt=\"FAO\" width=\"790\">"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "注意数据需转换成tensor才能参与后续建模训练\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]]) tensor([5, 0, 4,  ..., 8, 4, 8])\n",
      "torch.Size([50000, 784])\n",
      "tensor(0) tensor(9)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Convert the numpy arrays to torch tensors so they can be used for modeling and training.\n",
    "x_train, y_train, x_valid, y_valid = map(\n",
    "    torch.tensor, (x_train, y_train, x_valid, y_valid)\n",
    ")\n",
    "n, c = x_train.shape  # number of samples, features (pixels) per sample\n",
    "print(x_train, y_train)\n",
    "print(x_train.shape)\n",
    "# Label range check: 10 digit classes should give min 0 and max 9.\n",
    "print(y_train.min(), y_train.max())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### torch.nn.functional 很多层和函数在这里都会见到\n",
    "\n",
    "torch.nn.functional中有很多功能，后续会常用的。那什么时候使用nn.Module，什么时候使用nn.functional呢？一般情况下，如果模型有可学习的参数，最好用nn.Module，其他情况nn.functional相对更简单一些"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "# Cross-entropy loss. F.cross_entropy applies log-softmax itself,\n",
    "# so the model below returns raw, unnormalised class scores.\n",
    "loss_func = F.cross_entropy\n",
    "\n",
    "def model(xb):\n",
    "    \"\"\"Linear model: class scores = xb @ weights + bias.\n",
    "\n",
    "    `weights` and `bias` are globals created in a later cell, so this\n",
    "    function can only be called once that cell has run.\n",
    "    \"\"\"\n",
    "    scores = xb.mm(weights)\n",
    "    return scores + bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 784])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 64  # mini-batch size\n",
    "xb = x_train[:bs]  # the first bs training images\n",
    "xb.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(13.2326, grad_fn=<NllLossBackward0>)\n"
     ]
    }
   ],
   "source": [
    "bs = 64  # mini-batch size\n",
    "xb = x_train[0:bs]  # a mini-batch from x\n",
    "yb = y_train[0:bs]\n",
    "\n",
    "# Parameters of the linear `model`: maps 784 pixel values to 10 class scores.\n",
    "weights = torch.randn([784, 10], dtype=torch.float, requires_grad=True)\n",
    "bias = torch.zeros(10, requires_grad=True)\n",
    "\n",
    "# Loss of the untrained, randomly initialised linear model on this mini-batch.\n",
    "print(loss_func(model(xb), yb))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 创建一个model来更简化代码"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 必须继承nn.Module且在其构造函数中需调用nn.Module的构造函数\n",
    "- 无需写反向传播函数，nn.Module能够利用autograd自动实现反向传播\n",
    "- Module中的可学习参数可以通过named_parameters()或者parameters()返回迭代器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "from torch import nn\n",
    "\n",
    "# Subclass nn.Module and call its constructor first; autograd then provides\n",
    "# the backward pass, and each nn.Linear registers its weight and bias as\n",
    "# learnable parameters.\n",
    "class Mnist_NN(nn.Module):\n",
    "    \"\"\"Fully connected MNIST classifier: 784 -> 128 -> 256 -> 10 class scores.\n",
    "\n",
    "    Args:\n",
    "        p_dropout: drop probability applied to the activations of both hidden\n",
    "            layers (default 0.5).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, p_dropout=0.5):\n",
    "        super().__init__()\n",
    "        self.hidden1 = nn.Linear(784, 128)\n",
    "        self.hidden2 = nn.Linear(128, 256)\n",
    "        self.out = nn.Linear(256, 10)\n",
    "        # Active only in model.train() mode; model.eval() turns it into a no-op.\n",
    "        self.dropout = nn.Dropout(p_dropout)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map a batch of flattened images (batch, 784) to logits (batch, 10).\"\"\"\n",
    "        x = F.relu(self.hidden1(x))\n",
    "        x = self.dropout(x)  # regularise hidden activations against overfitting\n",
    "        x = F.relu(self.hidden2(x))\n",
    "        x = self.dropout(x)\n",
    "        return self.out(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mnist_NN(\n",
      "  (hidden1): Linear(in_features=784, out_features=128, bias=True)\n",
      "  (hidden2): Linear(in_features=128, out_features=256, bias=True)\n",
      "  (out): Linear(in_features=256, out_features=10, bias=True)\n",
      "  (dropout): Dropout(p=0.5, inplace=False)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Instantiate the network and print its layer structure.\n",
    "net = Mnist_NN()\n",
    "print(str(net))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[('hidden1.weight', (128, 784)), ('hidden1.bias', (128,)), ('hidden2.weight', (256, 128)), ('hidden2.bias', (256,)), ('out.weight', (10, 256)), ('out.bias', (10,))]\n"
     ]
    }
   ],
   "source": [
    "# named_parameters() yields (name, Parameter) pairs. List each name with its\n",
    "# shape instead of dumping every value; the loop further below prints values too.\n",
    "print([(name, tuple(param.shape)) for name, param in net.named_parameters()])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以打印我们定义好名字里的权重和偏置项"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hidden1.weight Parameter containing:\n",
      "tensor([[-0.0039, -0.0199, -0.0283,  ..., -0.0031, -0.0331, -0.0093],\n",
      "        [-0.0146, -0.0272, -0.0232,  ..., -0.0113, -0.0247,  0.0084],\n",
      "        [-0.0046,  0.0346, -0.0145,  ..., -0.0149,  0.0246,  0.0147],\n",
      "        ...,\n",
      "        [ 0.0266,  0.0188,  0.0161,  ..., -0.0143, -0.0229, -0.0261],\n",
      "        [-0.0042,  0.0055,  0.0312,  ...,  0.0167,  0.0302,  0.0309],\n",
      "        [ 0.0122, -0.0295,  0.0123,  ..., -0.0056,  0.0222,  0.0074]],\n",
      "       requires_grad=True) torch.Size([128, 784])\n",
      "hidden1.bias Parameter containing:\n",
      "tensor([ 0.0240,  0.0285, -0.0067,  0.0135, -0.0247, -0.0292,  0.0099,  0.0246,\n",
      "         0.0210, -0.0350, -0.0320, -0.0158, -0.0045, -0.0337, -0.0188, -0.0037,\n",
      "        -0.0043,  0.0259,  0.0161, -0.0013, -0.0021,  0.0104, -0.0137,  0.0194,\n",
      "         0.0354, -0.0340, -0.0315,  0.0351,  0.0195,  0.0327,  0.0273, -0.0353,\n",
      "         0.0158, -0.0264,  0.0225,  0.0239,  0.0166,  0.0327,  0.0344,  0.0068,\n",
      "        -0.0178,  0.0112, -0.0063,  0.0002,  0.0218, -0.0263, -0.0150,  0.0273,\n",
      "         0.0057, -0.0110,  0.0197,  0.0024,  0.0203, -0.0159, -0.0007,  0.0351,\n",
      "         0.0255,  0.0156,  0.0348, -0.0051,  0.0231,  0.0039,  0.0247,  0.0193,\n",
      "        -0.0096, -0.0085, -0.0090,  0.0158, -0.0107,  0.0275,  0.0342,  0.0117,\n",
      "        -0.0035, -0.0041,  0.0042,  0.0091,  0.0026, -0.0275,  0.0185, -0.0118,\n",
      "         0.0294,  0.0027,  0.0210,  0.0055, -0.0209,  0.0287,  0.0159, -0.0022,\n",
      "        -0.0153,  0.0116,  0.0274,  0.0344, -0.0144, -0.0165, -0.0279, -0.0167,\n",
      "        -0.0052, -0.0154,  0.0080,  0.0307,  0.0061, -0.0031, -0.0058,  0.0306,\n",
      "        -0.0059, -0.0258,  0.0022, -0.0181,  0.0200, -0.0296, -0.0059, -0.0205,\n",
      "         0.0089,  0.0101,  0.0024,  0.0238,  0.0146, -0.0121, -0.0172,  0.0122,\n",
      "        -0.0116,  0.0284,  0.0230, -0.0322, -0.0082,  0.0077, -0.0086, -0.0143],\n",
      "       requires_grad=True) torch.Size([128])\n",
      "hidden2.weight Parameter containing:\n",
      "tensor([[ 0.0778, -0.0772, -0.0126,  ..., -0.0830,  0.0002, -0.0519],\n",
      "        [-0.0591, -0.0795, -0.0287,  ..., -0.0062,  0.0520, -0.0329],\n",
      "        [-0.0540, -0.0397, -0.0759,  ..., -0.0597, -0.0539,  0.0437],\n",
      "        ...,\n",
      "        [-0.0344,  0.0584, -0.0618,  ..., -0.0364,  0.0273,  0.0387],\n",
      "        [-0.0682, -0.0209,  0.0263,  ...,  0.0718, -0.0805,  0.0379],\n",
      "        [-0.0645,  0.0447, -0.0329,  ..., -0.0212,  0.0311, -0.0268]],\n",
      "       requires_grad=True) torch.Size([256, 128])\n",
      "hidden2.bias Parameter containing:\n",
      "tensor([ 0.0752, -0.0442, -0.0684, -0.0343,  0.0831, -0.0172, -0.0245, -0.0752,\n",
      "         0.0774,  0.0511,  0.0240, -0.0279, -0.0444,  0.0032,  0.0423, -0.0610,\n",
      "        -0.0729,  0.0641,  0.0497, -0.0097,  0.0532,  0.0606,  0.0038,  0.0876,\n",
      "        -0.0396,  0.0872, -0.0372,  0.0463,  0.0343, -0.0046,  0.0502,  0.0296,\n",
      "        -0.0718,  0.0300,  0.0025,  0.0613, -0.0553,  0.0591,  0.0291,  0.0286,\n",
      "        -0.0013,  0.0464, -0.0110,  0.0168,  0.0211,  0.0273, -0.0809,  0.0752,\n",
      "        -0.0011, -0.0198, -0.0193,  0.0590, -0.0710, -0.0044,  0.0206, -0.0596,\n",
      "        -0.0263,  0.0227, -0.0289, -0.0470, -0.0414, -0.0690, -0.0421, -0.0144,\n",
      "         0.0336,  0.0489,  0.0280,  0.0146,  0.0497,  0.0615,  0.0588,  0.0644,\n",
      "         0.0331, -0.0532, -0.0344,  0.0664, -0.0725,  0.0467,  0.0412,  0.0785,\n",
      "         0.0442,  0.0850, -0.0326,  0.0703,  0.0393, -0.0467, -0.0578, -0.0686,\n",
      "        -0.0427,  0.0154, -0.0440, -0.0244, -0.0249,  0.0358, -0.0154, -0.0577,\n",
      "        -0.0353,  0.0384, -0.0515,  0.0096, -0.0110,  0.0193, -0.0689, -0.0694,\n",
      "        -0.0192,  0.0807, -0.0871, -0.0665,  0.0572,  0.0184, -0.0793, -0.0161,\n",
      "         0.0441,  0.0570,  0.0307,  0.0131, -0.0862, -0.0419,  0.0501, -0.0485,\n",
      "        -0.0049,  0.0619, -0.0145, -0.0487, -0.0717, -0.0720,  0.0108,  0.0404,\n",
      "        -0.0166, -0.0371,  0.0870,  0.0812,  0.0243, -0.0090,  0.0869, -0.0764,\n",
      "         0.0445,  0.0027,  0.0201, -0.0649,  0.0319,  0.0863,  0.0559, -0.0169,\n",
      "        -0.0353, -0.0708,  0.0376,  0.0104, -0.0735, -0.0037, -0.0793,  0.0091,\n",
      "        -0.0527,  0.0600,  0.0400, -0.0131, -0.0511,  0.0191, -0.0538,  0.0290,\n",
      "         0.0586, -0.0504,  0.0755, -0.0392, -0.0587,  0.0435,  0.0700, -0.0327,\n",
      "         0.0077,  0.0093,  0.0672, -0.0544, -0.0072, -0.0076,  0.0550, -0.0255,\n",
      "        -0.0373,  0.0591, -0.0495, -0.0843, -0.0441,  0.0556, -0.0023, -0.0025,\n",
      "         0.0034,  0.0858,  0.0499, -0.0590, -0.0049,  0.0161,  0.0862,  0.0013,\n",
      "        -0.0237,  0.0681,  0.0534, -0.0094,  0.0549, -0.0672,  0.0507, -0.0249,\n",
      "        -0.0872,  0.0268,  0.0628,  0.0107,  0.0154, -0.0116,  0.0565, -0.0339,\n",
      "        -0.0724, -0.0448,  0.0469, -0.0777,  0.0615, -0.0026, -0.0679, -0.0737,\n",
      "         0.0849, -0.0644,  0.0345, -0.0832,  0.0328,  0.0274, -0.0546, -0.0719,\n",
      "        -0.0741, -0.0879,  0.0062, -0.0580, -0.0098,  0.0049,  0.0593, -0.0088,\n",
      "         0.0689,  0.0746, -0.0726, -0.0607,  0.0172, -0.0627,  0.0574, -0.0247,\n",
      "         0.0518, -0.0012, -0.0552,  0.0768, -0.0008, -0.0402,  0.0476,  0.0079,\n",
      "         0.0049,  0.0212,  0.0607,  0.0465, -0.0583,  0.0065,  0.0152, -0.0879],\n",
      "       requires_grad=True) torch.Size([256])\n",
      "out.weight Parameter containing:\n",
      "tensor([[-0.0354, -0.0021, -0.0492,  ...,  0.0511,  0.0474,  0.0349],\n",
      "        [-0.0051, -0.0075,  0.0532,  ...,  0.0311,  0.0291, -0.0593],\n",
      "        [-0.0123, -0.0156,  0.0115,  ...,  0.0350,  0.0467,  0.0452],\n",
      "        ...,\n",
      "        [ 0.0367,  0.0604,  0.0029,  ..., -0.0470,  0.0550, -0.0243],\n",
      "        [ 0.0449, -0.0520, -0.0368,  ..., -0.0351,  0.0086, -0.0052],\n",
      "        [ 0.0040, -0.0037, -0.0430,  ...,  0.0590,  0.0021, -0.0248]],\n",
      "       requires_grad=True) torch.Size([10, 256])\n",
      "out.bias Parameter containing:\n",
      "tensor([ 0.0521, -0.0248,  0.0278, -0.0423,  0.0194,  0.0251,  0.0509,  0.0491,\n",
      "         0.0334,  0.0505], requires_grad=True) torch.Size([10])\n"
     ]
    }
   ],
   "source": [
    "# Print each learnable parameter: its name, its values, and its shape.\n",
    "for param_name, param in net.named_parameters():\n",
    "    print(param_name, param, param.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 使用TensorDataset和DataLoader来简化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Wrap the tensors so that indexing yields (image, label) pairs.\n",
    "train_ds = TensorDataset(x_train, y_train)\n",
    "valid_ds = TensorDataset(x_valid, y_valid)\n",
    "\n",
    "# Training batches are reshuffled every epoch; validation keeps a fixed order\n",
    "# and uses batches twice as large.\n",
    "train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)\n",
    "valid_dl = DataLoader(valid_ds, batch_size=2 * bs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data(train_ds, valid_ds, bs):\n",
    "    \"\"\"Return (train_loader, valid_loader) for the given datasets.\n",
    "\n",
    "    The training loader shuffles and uses batch size `bs`; the validation\n",
    "    loader keeps the dataset order and uses batch size `2 * bs`.\n",
    "    \"\"\"\n",
    "    train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True)\n",
    "    valid_loader = DataLoader(valid_ds, batch_size=bs * 2)\n",
    "    return train_loader, valid_loader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 训练模型时一般先调用model.train()：Dropout随机丢弃神经元，Batch Normalization使用当前batch的均值和方差\n",
    "- 验证/测试时一般调用model.eval()：Dropout不再生效，Batch Normalization改用训练过程中累积的均值和方差（并非不使用BN）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def fit(steps, model, loss_func, opt, train_dl, valid_dl):\n",
    "    \"\"\"Train `model` for `steps` epochs, printing the validation loss after each.\n",
    "\n",
    "    Each epoch makes one optimisation pass over `train_dl` in train mode, then\n",
    "    evaluates on `valid_dl` in eval mode without tracking gradients. The printed\n",
    "    value is the mean of the batch losses weighted by batch size (a per-sample\n",
    "    mean, assuming loss_func averages over its batch).\n",
    "    Uses `loss_batch`, which is defined in a later cell.\n",
    "    \"\"\"\n",
    "    for epoch in range(steps):\n",
    "        # Training pass: loss_batch updates the weights when given an optimizer.\n",
    "        model.train()\n",
    "        for x_batch, y_batch in train_dl:\n",
    "            loss_batch(model, loss_func, x_batch, y_batch, opt)\n",
    "\n",
    "        # Validation pass: no optimizer, so no weight updates.\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            results = [loss_batch(model, loss_func, x_batch, y_batch) for x_batch, y_batch in valid_dl]\n",
    "        losses, nums = zip(*results)\n",
    "        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)\n",
    "        print('当前step:'+str(epoch), '验证集损失：'+str(val_loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import optim\n",
    "\n",
    "def get_model(lr=0.001):\n",
    "    \"\"\"Create a fresh Mnist_NN and an Adam optimizer over its parameters.\n",
    "\n",
    "    Args:\n",
    "        lr: Adam learning rate (default 0.001).\n",
    "\n",
    "    Returns:\n",
    "        A (model, optimizer) tuple.\n",
    "    \"\"\"\n",
    "    model = Mnist_NN()\n",
    "    return model, optim.Adam(model.parameters(), lr=lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_batch(model, loss_func, xb, yb, opt=None):\n",
    "    \"\"\"Compute the loss on one batch and, if `opt` is given, take an optimisation step.\n",
    "\n",
    "    Returns:\n",
    "        (loss value as a Python number, number of samples in the batch).\n",
    "    \"\"\"\n",
    "    batch_loss = loss_func(model(xb), yb)\n",
    "\n",
    "    if opt is None:\n",
    "        # Evaluation only: no backward pass and no parameter update.\n",
    "        return batch_loss.item(), len(xb)\n",
    "\n",
    "    batch_loss.backward()\n",
    "    opt.step()\n",
    "    # Clear gradients so they do not accumulate into the next batch.\n",
    "    opt.zero_grad()\n",
    "    return batch_loss.item(), len(xb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 三行搞定！"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "当前step:0 验证集损失：0.14866696909070015\n",
      "当前step:1 验证集损失：0.10762357178330421\n",
      "当前step:2 验证集损失：0.09148775678351521\n",
      "当前step:3 验证集损失：0.09239266448542476\n",
      "当前step:4 验证集损失：0.09119845449198037\n",
      "当前step:5 验证集损失：0.09368460708372295\n",
      "当前step:6 验证集损失：0.09543108020629734\n",
      "当前step:7 验证集损失：0.12964335346361622\n",
      "当前step:8 验证集损失：0.1064905257988372\n",
      "当前step:9 验证集损失：0.0952181197460508\n",
      "当前step:10 验证集损失：0.09992968667526002\n",
      "当前step:11 验证集损失：0.1093115426942466\n",
      "当前step:12 验证集损失：0.1053862185144244\n",
      "当前step:13 验证集损失：0.13352388969183956\n",
      "当前step:14 验证集损失：0.12262713425904731\n",
      "当前step:15 验证集损失：0.12021054504886251\n",
      "当前step:16 验证集损失：0.13394502818775944\n",
      "当前step:17 验证集损失：0.1302742688085876\n",
      "当前step:18 验证集损失：0.13685556791070835\n",
      "当前step:19 验证集损失：0.13136188031951315\n",
      "当前step:20 验证集损失：0.13773769891970727\n",
      "当前step:21 验证集损失：0.14880260401972592\n",
      "当前step:22 验证集损失：0.13825004334147087\n",
      "当前step:23 验证集损失：0.13655680389303015\n",
      "当前step:24 验证集损失：0.16287535354070132\n"
     ]
    }
   ],
   "source": [
    "# Seed PyTorch's RNG so weight initialisation, shuffling and dropout masks are\n",
    "# repeatable across runs (on the same hardware and library versions).\n",
    "torch.manual_seed(0)\n",
    "train_dl, valid_dl = get_data(train_ds, valid_ds, bs)\n",
    "# Note: this rebinds the name `model`, which earlier referred to the linear model function.\n",
    "model, opt = get_model()\n",
    "fit(25, model, loss_func, opt, train_dl, valid_dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
